{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "51290d2a",
   "metadata": {},
   "source": [
    "# Automatic Speech Recogntion with Hugging Face's Transformers & Amazon SageMaker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dfca127",
   "metadata": {},
   "source": [
    "Transformer models are changing the world of machine learning, starting with natural language processing, and now, with audio and computer vision.  Hugging Face's mission is to democratize good machine learning and give anyone the opportunity to use these new state-of-the-art machine learning models. \n",
    "Together with Amazon SageMaker and AWS have we been working on extending the functionalities of the Hugging Face Inference DLC and the Python SageMaker SDK to make it easier to use speech and vision models together with `transformers`. \n",
    "You can now use the Hugging Face Inference DLC to do [automatic speech recognition](https://huggingface.co/tasks/automatic-speech-recognition) using MetaAIs [wav2vec2](https://arxiv.org/abs/2006.11477) model or Microsofts [WavLM](https://arxiv.org/abs/2110.13900) or use NVIDIAs [SegFormer](https://arxiv.org/abs/2105.15203) for [semantic segmentation](https://huggingface.co/tasks/image-segmentation).\n",
    "\n",
    "\n",
    "This guide will walk you through how to do [automatic speech recognition](https://huggingface.co/tasks/automatic-speech-recognition) using [wav2veec2](https://huggingface.co/facebook/wav2vec2-base-960h) and new `DataSerializer`.\n",
    "\n",
    "![automatic_speech_recognition](imgs/automatic_speech_recognition.png)\n",
    "\n",
    "\n",
    "In this example you will learn how to: \n",
    "\n",
    "1. Setup a development Environment and permissions for deploying Amazon SageMaker Inference Endpoints.\n",
    "2. Deploy a wav2vec2 model to Amazon SageMaker for automatic speech recogntion\n",
    "3. Send requests to the endpoint to do speech recognition.\n",
    "   \n",
    "Let's get started! 🚀\n",
    "\n",
    "---\n",
    "\n",
    "*If you are going to use Sagemaker in a local environment (not SageMaker Studio or Notebook Instances). You need access to an IAM Role with the required permissions for Sagemaker. You can find [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) more about it.*\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d66bd6bb",
   "metadata": {},
   "source": [
    "## 1. Setup a development Environment and permissions for deploying Amazon SageMaker Inference Endpoints.\n",
    "\n",
    "Setting up the development environment and permissions needs to be done for the automatic-speech-recognition example and the semantic-segmentation example. First we update the `sagemaker` SDK to make sure we have new `DataSerializer`. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad20442",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e4386d9",
   "metadata": {},
   "source": [
    "After we have update the SDK we can set the permissions.\n",
    "\n",
    "_If you are going to use Sagemaker in a local environment (not SageMaker Studio or Notebook Instances). You need access to an IAM Role with the required permissions for Sagemaker. You can find [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) more about it._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1c22e8d5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sagemaker role arn: arn:aws:iam::558105141721:role/sagemaker_execution_role\n",
      "sagemaker bucket: sagemaker-us-east-1-558105141721\n",
      "sagemaker session region: us-east-1\n"
     ]
    }
   ],
   "source": [
    "import sagemaker\n",
    "import boto3\n",
    "sess = sagemaker.Session()\n",
    "# sagemaker session bucket -> used for uploading data, models and logs\n",
    "# sagemaker will automatically create this bucket if it not exists\n",
    "sagemaker_session_bucket=None\n",
    "if sagemaker_session_bucket is None and sess is not None:\n",
    "    # set to default bucket if a bucket name is not given\n",
    "    sagemaker_session_bucket = sess.default_bucket()\n",
    "\n",
    "try:\n",
    "    role = sagemaker.get_execution_role()\n",
    "except ValueError:\n",
    "    iam = boto3.client('iam')\n",
    "    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
    "\n",
    "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
    "\n",
    "print(f\"sagemaker role arn: {role}\")\n",
    "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n",
    "print(f\"sagemaker session region: {sess.boto_region_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f683db52",
   "metadata": {},
   "source": [
    "## 2. Deploy a wav2vec2 model to Amazon SageMaker for automatic speech recogntion\n",
    "\n",
    "\n",
    "Automatic Speech Recognition (ASR), also known as Speech to Text (STT), is the task of transcribing a given audio to text. It has many applications, such as voice user interfaces.\n",
    "\n",
    "We use the [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) model running our recognition endpoint. This model is a fine-tune checkpoint of [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) pretrained and fine-tuned on 960 hours of Librispeech on 16kHz sampled speech audio achieving 1.8/3.3 WER on the clean/other test sets.\n"
   ]
  },
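  {
   "cell_type": "markdown",
   "id": "7a3c9e10",
   "metadata": {},
   "source": [
    "WER (word error rate) counts the word-level substitutions, deletions, and insertions needed to turn the model's transcript into the reference, divided by the number of reference words. As a rough illustration, a minimal sketch might look like the following. The helper name `word_error_rate` is our own and not part of `transformers` or `sagemaker`, and the sketch skips the text normalization that real evaluations usually apply.\n",
    "\n",
    "```python\n",
    "def word_error_rate(reference: str, hypothesis: str) -> float:\n",
    "    '''Word-level Levenshtein distance divided by the reference length.'''\n",
    "    ref = reference.split()\n",
    "    hyp = hypothesis.split()\n",
    "    if not ref:\n",
    "        raise ValueError('reference must contain at least one word')\n",
    "\n",
    "    # prev[j] is the edit distance between the reference words seen so far and hyp[:j]\n",
    "    prev = list(range(len(hyp) + 1))\n",
    "    for i, ref_word in enumerate(ref, start=1):\n",
    "        curr = [i]\n",
    "        for j, hyp_word in enumerate(hyp, start=1):\n",
    "            substitution = prev[j - 1] + (ref_word != hyp_word)\n",
    "            deletion = prev[j] + 1\n",
    "            insertion = curr[j - 1] + 1\n",
    "            curr.append(min(substitution, deletion, insertion))\n",
    "        prev = curr\n",
    "    return prev[-1] / len(ref)\n",
    "\n",
    "\n",
    "# one substituted word out of four reference words\n",
    "print(word_error_rate('GOING ALONG SLUSHY ROADS', 'GOING ALONG SLUSHY ROAD'))\n",
    "```\n",
    "\n",
    "With one substituted word out of four reference words, the call above prints `0.25`."
   ]
  },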
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a527872b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.huggingface.model import HuggingFaceModel\n",
    "from sagemaker.serializers import DataSerializer\n",
    "\n",
    "# Hub Model configuration. <https://huggingface.co/models>\n",
    "hub = {\n",
    "    'HF_MODEL_ID':'facebook/wav2vec2-base-960h',\n",
    "    'HF_TASK':'automatic-speech-recognition',\n",
    "}\n",
    "\n",
    "# create Hugging Face Model Class\n",
    "huggingface_model = HuggingFaceModel(\n",
    "   env=hub,                      # configuration for loading model from Hub\n",
    "   role=role,                    # iam role with permissions to create an Endpoint\n",
    "   transformers_version=\"4.26\",  # transformers version used\n",
    "   pytorch_version=\"1.13\",        # pytorch version used\n",
    "   py_version='py39',            # python version used\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "057ccd37",
   "metadata": {},
   "source": [
    "Before we are able to deploy our `HuggingFaceModel` class we need to create a new serializer, which supports our audio data. The Serializer are used in Predictor and in the `predict` method to serializer our data to a specific `mime-type`, which send to the endpoint. The default serialzier for the HuggingFacePredcitor is a JSNON serializer, but since we are not going to send text data to the endpoint we will use the DataSerializer."
   ]
  },
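  {
   "cell_type": "markdown",
   "id": "b2d4f6a8",
   "metadata": {},
   "source": [
    "To make the serializer's role more concrete: a string passed to `predict` is treated as a path to a local file whose bytes are sent, while a bytes object is sent as is. The snippet below is a simplified stand-in written for illustration, not the SDK's implementation; `to_request_body` is a hypothetical helper and the temporary file holds placeholder bytes rather than real audio.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def to_request_body(data):\n",
    "    '''Return the raw bytes that would be sent as the request body.'''\n",
    "    if isinstance(data, str):\n",
    "        # a string is interpreted as a path to a local file\n",
    "        with open(data, 'rb') as data_file:\n",
    "            return data_file.read()\n",
    "    if isinstance(data, bytes):\n",
    "        # bytes are passed through unchanged\n",
    "        return data\n",
    "    raise ValueError(f'unsupported input type: {type(data).__name__}')\n",
    "\n",
    "\n",
    "# write a few placeholder bytes to a temporary file (not real audio)\n",
    "with tempfile.NamedTemporaryFile(suffix='.flac', delete=False) as tmp:\n",
    "    tmp.write(b'fake-audio-bytes')\n",
    "\n",
    "# both input styles produce the same request body\n",
    "assert to_request_body(tmp.name) == to_request_body(b'fake-audio-bytes')\n",
    "os.remove(tmp.name)\n",
    "```\n",
    "\n",
    "Either way the endpoint receives the same bytes, which is why the two request styles in section 3 return identical transcripts."
   ]
  },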
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1681dd7e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------!"
     ]
    }
   ],
   "source": [
    "# create a serializer for the data\n",
    "audio_serializer = DataSerializer(content_type='audio/x-audio') # using x-audio to support multiple audio formats\n",
    "\n",
    "# deploy model to SageMaker Inference\n",
    "predictor = huggingface_model.deploy(\n",
    "\tinitial_instance_count=1, # number of instances\n",
    "\tinstance_type='ml.g4dn.xlarge', # ec2 instance type\n",
    "  serializer=audio_serializer, # serializer for our audio data.\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6b3812f",
   "metadata": {},
   "source": [
    "## 3. Send requests to the endpoint to do speech recognition.\n",
    "\n",
    "The `.deploy()` returns an `HuggingFacePredictor` object with our `DataSeriliazer` which can be used to request inference. This `HuggingFacePredictor` makes it easy to send requests to your endpoint and get the results back.\n",
    "\n",
    "We will use 3 different methods to send requests to the endpoint:\n",
    "\n",
    "a. Provide a audio file via path to the predictor  \n",
    "b. Provide binary audio data object to the predictor  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed57d108",
   "metadata": {},
   "source": [
    "### a. Provide a audio file via path to the predictor\n",
    "\n",
    "Using a audio file as input is easy as easy as providing the path to its location. The `DataSerializer` will then read it and send the bytes to the endpoint. \n",
    "\n",
    "We can use a `libirispeech` sample hosted on huggingface.co"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0b176897",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-03-21 08:27:32--  https://cdn-media.huggingface.co/speech_samples/sample1.flac\n",
      "Resolving cdn-media.huggingface.co (cdn-media.huggingface.co)... 18.160.46.12, 18.160.46.81, 18.160.46.78, ...\n",
      "Connecting to cdn-media.huggingface.co (cdn-media.huggingface.co)|18.160.46.12|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 282378 (276K) [audio/flac]\n",
      "Saving to: ‘sample1.flac’\n",
      "\n",
      "100%[======================================>] 282,378     --.-K/s   in 0.004s  \n",
      "\n",
      "2023-03-21 08:27:32 (69.1 MB/s) - ‘sample1.flac’ saved [282378/282378]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://cdn-media.huggingface.co/speech_samples/sample1.flac"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "083f008e",
   "metadata": {},
   "source": [
    "To send a request with provide our path to the audio file we can use the following code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "51c5366d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'text': \"GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOL ROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS\"}\n"
     ]
    }
   ],
   "source": [
    "audio_path = \"sample1.flac\"\n",
    "\n",
    "res = predictor.predict(data=audio_path)\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "771966ba",
   "metadata": {},
   "source": [
    "### b. Provide binary audio data object to the predictor\n",
    "\n",
    "Instead of providing a path to the audio file we can also directy provide the bytes of it reading the file in python.\n",
    "\n",
    "\n",
    "_make sure `sample1.flac` is in the directory_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "51c5366b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'text': \"GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP AUDIENCES IN DRAUGHTY SCHOOL ROOMS DAY AFTER DAY FOR A FORTNIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS\"}\n"
     ]
    }
   ],
   "source": [
    "audio_path = \"sample1.flac\"\n",
    "\n",
    "with open(audio_path, \"rb\") as data_file:\n",
    "  audio_data = data_file.read()\n",
    "  res = predictor.predict(data=audio_data)\n",
    "  print(res)"
   ]
  },
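  {
   "cell_type": "markdown",
   "id": "e5c71d39",
   "metadata": {},
   "source": [
    "The predictor is a convenience wrapper. An application that does not use the SageMaker Python SDK could instead call the endpoint through the `invoke_endpoint` operation of the boto3 `sagemaker-runtime` client. The following is a minimal sketch under assumptions: `transcribe` is a hypothetical helper, and the endpoint name is whatever `predictor.endpoint_name` reports for your deployment.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def transcribe(runtime_client, endpoint_name, audio_bytes):\n",
    "    '''Send raw audio bytes to the endpoint and return the parsed JSON response.'''\n",
    "    response = runtime_client.invoke_endpoint(\n",
    "        EndpointName=endpoint_name,\n",
    "        ContentType='audio/x-audio',  # same content type as the DataSerializer above\n",
    "        Body=audio_bytes,\n",
    "    )\n",
    "    # the response body is a stream containing the JSON-encoded prediction\n",
    "    return json.loads(response['Body'].read())\n",
    "\n",
    "\n",
    "# usage against a live endpoint (requires AWS credentials), left commented out:\n",
    "# import boto3\n",
    "# runtime = boto3.client('sagemaker-runtime')\n",
    "# with open('sample1.flac', 'rb') as f:\n",
    "#     print(transcribe(runtime, predictor.endpoint_name, f.read()))\n",
    "```\n",
    "\n",
    "Passing the client in as an argument keeps the helper easy to exercise with a stub before pointing it at a real endpoint."
   ]
  },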
  {
   "cell_type": "markdown",
   "id": "2f9817a1",
   "metadata": {},
   "source": [
    "## Clean up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1e6fb7b8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "predictor.delete_model()\n",
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "c281c456f1b8161c8906f4af2c08ed2c40c50136979eaae69688b01f70e9f4a9"
  },
  "kernelspec": {
   "display_name": "conda_pytorch_p39",
   "language": "python",
   "name": "conda_pytorch_p39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
